import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger, TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.offline_model_configs import OfflineModelTrainConfig, OfflineModel_DEFAULT_CONFIG
from osrl.algorithms import EnsembleDynamics, EnsembleDynamicsModel, EnsembleCostModel
from osrl.common import TransitionDataset
from osrl.common.exp_util import auto_name, seed_all
from osrl.common.model_logger import Logger
from osrl.common.net import StandardScaler, SimpleScaler, termination_fn_common


def _merge_with_task_defaults(
    args: OfflineModelTrainConfig,
) -> Tuple[types.SimpleNamespace, Dict[str, Any], Dict[str, Any]]:
    """Overlay the user-changed fields of ``args`` on the task's default config.

    A field counts as user-changed when its value differs from the generic
    ``OfflineModelTrainConfig()`` default. Every other field is taken from
    ``OfflineModel_DEFAULT_CONFIG[args.task]``. As a consequence, an explicitly passed
    value that equals the generic default is replaced by the task default.

    Returns:
        ``(namespace, merged_cfg, task_default_cfg)``:
        - ``namespace`` is the merged config as an attribute namespace.
        - ``merged_cfg`` is the same merged config as a plain dict.
        - ``task_default_cfg`` is the untouched task defaults.
        Later attribute assignments on the namespace do not modify ``merged_cfg``.
    """
    cfg, generic_cfg = asdict(args), asdict(OfflineModelTrainConfig())
    overrides = {key: cfg[key] for key in cfg if cfg[key] != generic_cfg[key]}
    # Build the task defaults once; reuse them for both merging and auto-naming.
    task_default_cfg = asdict(OfflineModel_DEFAULT_CONFIG[args.task]())
    merged_cfg = {**task_default_cfg, **overrides}
    return types.SimpleNamespace(**merged_cfg), merged_cfg, task_default_cfg


def _setup_loggers(
    args: types.SimpleNamespace, cfg: Dict[str, Any], default_cfg: Dict[str, Any]
) -> Tuple[TensorboardLogger, Logger]:
    """Resolve the run name/group/log directory on ``args`` and create the loggers.

    Mutates ``args.name``, ``args.group`` and ``args.logdir`` in place. ``cfg`` is
    saved as given, so the saved config does not reflect those mutations.

    Returns:
        ``(tb_logger, model_logger)``: the fsrl TensorBoard logger, and the osrl model
        ``Logger`` writing into the same log directory.
    """
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)

    # TODO: restore the W&B logger. It was swapped out for TensorBoard, presumably
    # while debugging; the W&B equivalent is:
    #   WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    tb_logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    tb_logger.save_config(cfg, verbose=args.verbose)

    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "dynamics_training_progress": "csv",
        "tb": "tensorboard",
    }
    model_logger = Logger(tb_logger.log_dir, output_config)
    return tb_logger, model_logger


def _load_dataset(args: types.SimpleNamespace) -> Tuple[Any, Any]:
    """Create the task environment and return ``(env, data)``.

    ``data`` is the environment's offline dataset after ``pre_process_data``. When
    ``args.safe_only`` is set, it is further restricted to zero-cost transitions.
    """
    # MetaDrive tasks are created through the legacy ``gym`` package; all other tasks
    # through ``gymnasium``.
    if "Metadrive" in args.task:
        env = gym_org.make(args.task)
    else:
        env = gym.make(args.task)

    data = env.get_dataset()
    env.set_target_cost(args.cost_limit)

    # Density-based resampling needs per-task bin settings; density == 1.0 means none.
    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    data = env.pre_process_data(data,
                                args.outliers_percent,
                                args.noise_scale,
                                args.inpaint_ranges,
                                args.epsilon,
                                args.density,
                                cbins=cbins,
                                rbins=rbins,
                                max_npb=max_npb,
                                min_npb=min_npb)

    if args.safe_only:
        # Keep only transitions whose cost is exactly zero. This assumes every entry of
        # ``data`` is row-aligned with ``data["costs"]``.
        safe_mask = data["costs"] == 0
        for key in data.keys():
            data[key] = data[key][safe_mask]
    return env, data


def _build_dynamics(args: types.SimpleNamespace, env: Any) -> EnsembleDynamics:
    """Construct the ensemble dynamics/cost models, their optimizers and schedulers."""
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dynamics_model = EnsembleDynamicsModel(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_dims=args.dynamic_hidden_dims,
        num_ensemble=args.num_ensemble,
        num_elites=args.num_elites,
        weight_decays=args.dynamic_weight_decays,
        with_cost=args.with_cost,
        device=args.device
    )
    cost_model = EnsembleCostModel(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_dims=args.cost_model_hidden_dims,
        num_ensemble=args.num_ensemble,
        num_elites=args.num_elites,
        weight_decays=args.dynamic_weight_decays,
        device=args.device
    )
    print(f"Total parameters: {sum(p.numel() for p in dynamics_model.parameters())}")

    dynamics_optim = torch.optim.Adam(dynamics_model.parameters(), lr=args.learning_rate)
    cost_model_optim = torch.optim.Adam(cost_model.parameters(), lr=args.learning_rate)
    dynamics_scheduler = torch.optim.lr_scheduler.StepLR(
        dynamics_optim, step_size=args.decay_step, gamma=args.decay_rate)
    cost_model_scheduler = torch.optim.lr_scheduler.StepLR(
        cost_model_optim, step_size=args.decay_step, gamma=args.decay_rate)

    scaler = SimpleScaler() if args.simple_scaler else StandardScaler()

    if args.safe_only:
        # The safe-only dataset has no non-zero costs, so no cost model is trained.
        # It is still constructed above so that, if weight init draws from the seeded
        # RNG, the random stream does not depend on ``safe_only``.
        cost_model = None
        cost_model_optim = None
        cost_model_scheduler = None

    return EnsembleDynamics(
        dynamics_model,
        cost_model,
        dynamics_optim,
        cost_model_optim,
        scaler,
        termination_fn_common,
        use_scheduler=args.use_scheduler,
        dynamics_scheduler=dynamics_scheduler,
        cost_model_scheduler=cost_model_scheduler,
        penalty_coef=args.penalty_coef,
        with_cost=args.with_cost,
        use_delta_obs=args.use_delta_obs,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        cost_coef=args.cost_coef
    )


@pyrallis.wrap()
def train(args: OfflineModelTrainConfig):
    """Train an ensemble dynamics (and optionally cost) model on an offline dataset.

    ``args`` is supplied by ``pyrallis.wrap()``; the script calls ``train()`` with no
    arguments. Steps, in order:
    1. Merge the user-changed fields over the task defaults.
    2. Set up the loggers.
    3. Seed everything.
    4. Load and pre-process the dataset.
    5. Wrap the environment.
    6. Build the models and run ``EnsembleDynamics.train``.
    The order matters: seeding happens before the environment and models are created.
    """
    args, cfg, default_cfg = _merge_with_task_defaults(args)
    # Hold a reference to the TensorBoard logger for the whole run.
    tb_logger, model_logger = _setup_loggers(args, cfg, default_cfg)

    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    env, data = _load_dataset(args)
    env = OfflineEnvWrapper(wrap_env(env=env, reward_scale=args.reward_scale))

    dynamics = _build_dynamics(args, env)
    dynamics.train(data, model_logger, batch_size=args.batch_size,
                   max_epochs_since_update=args.max_epochs_since_update)


if __name__ == "__main__":
    # pyrallis.wrap() builds the config argument from the command line.
    train()
